#!/usr/bin/env python3

import argparse
from gym_minigrid.wrappers import *
from gym_minigrid.window import Window

from envs.minigrid.adversarial import *
from utils.env import make_env


def dummy_event_detector(noise_level):
    """Simulate an event detector by perturbing the ground-truth events.

    Builds a 0.0/1.0 indicator vector over ``env.env.letter_types`` (1.0 where
    the letter appears in ``env.get_events()``), adds uniform noise drawn from
    ``[0, noise_level)`` to every entry, and clips the result to ``[0, 1]``.
    Because the noise is never negative (assuming ``noise_level >= 0``),
    entries for true events always remain exactly 1.0 after clipping.

    Args:
        noise_level: Scale of the additive uniform noise; 0.0 gives a
            perfect, noise-free detector.

    Returns:
        np.ndarray of float predictions, one per letter type.
    """
    observed = env.get_events()
    indicator = np.array(
        [float(letter in observed) for letter in env.env.letter_types]
    )
    perturbation = noise_level * np.random.random(indicator.shape)
    return np.clip(indicator + perturbation, 0., 1.)

def redraw(img):
    """Show a frame in the window.

    Without ``--agent_view`` the passed image is ignored and the full grid is
    re-rendered from ``base_env`` at ``args.tile_size``. With ``--agent_view``
    the given image (the observation handed in by callers) is shown as-is.
    """
    frame = img if args.agent_view else base_env.render(
        mode='rgb_array', tile_size=args.tile_size)
    window.show_img(frame)

def reset():
    """Start a new episode and redraw it.

    The environment is re-seeded with ``args.seed`` first, unless the seed
    is the sentinel -1 (meaning "do not reseed").
    """
    seeded = args.seed != -1
    if seeded:
        env.seed(args.seed)
    redraw(env.reset())

def step(action):
    """Apply one action to the environment and refresh the display.

    When belief tracking is enabled (``perfect_detector`` or
    ``independent_belief``), the RM beliefs are updated from the simulated
    detector *before* the action is applied. The detector is noise-free for
    ``perfect_detector`` and uses noise level 0.2 otherwise. After the step,
    progress is printed. A finished episode is reset; otherwise the new
    observation is redrawn.
    """
    tracks_beliefs = args.rm_update_algo in ('perfect_detector', 'independent_belief')
    if tracks_beliefs:
        noise = 0.0 if args.rm_update_algo == 'perfect_detector' else 0.2
        env.update_rm_beliefs(dummy_event_detector(noise), logit=False)

    obs, reward, done, info = env.step(action)
    print('step=%s, reward=%.2f' % (base_env.step_count, reward))
    print(f'current_u_id={env.current_u_id}')
    if tracks_beliefs:
        print(f'belief_u_dist={env.belief_u_dist}')

    if not done:
        redraw(obs)
        return
    print('done!')
    reset()

# Maps a pressed key to the name of the ``base_env.actions`` member it triggers.
# Names are resolved lazily with getattr, so an action missing from the
# environment's action set only fails when its key is actually pressed.
_KEY_TO_ACTION_NAME = {
    'left': 'left',
    'right': 'right',
    'up': 'forward',
    'down': 'backward',
    't': 'turn180',
    'w': 'wait',
    ' ': 'toggle',  # spacebar
    'pageup': 'pickup',
    'pagedown': 'drop',
    'enter': 'done',
}

def key_handler(event):
    """Handle a key press coming from the window.

    The key name is always printed first. Then:

    - ``escape`` closes the window.
    - ``backspace`` resets the episode.
    - Any key in ``_KEY_TO_ACTION_NAME`` steps the environment with the
      matching action.
    - Any other key is ignored.
    """
    key = event.key
    print('pressed', key)

    if key == 'escape':
        window.close()
    elif key == 'backspace':
        reset()
    elif key in _KEY_TO_ACTION_NAME:
        step(getattr(base_env.actions, _KEY_TO_ACTION_NAME[key]))

parser = argparse.ArgumentParser(
    description='Manually control a MiniGrid environment with the keyboard.'
)
parser.add_argument(
    "--env",
    help="gym environment to load",
    default='MiniGrid-MultiRoom-N6-v0'
)
parser.add_argument(
    "--seed",
    type=int,
    help="random seed to generate the environment with (-1: do not reseed)",
    default=-1
)
parser.add_argument(
    "--tile_size",
    type=int,
    help="size at which to render tiles",
    default=32
)
parser.add_argument(
    '--agent_view',
    default=False,
    help="draw what the agent sees (partially observable view)",
    action='store_true'
)

# `step` enables simulated belief updates for `perfect_detector` and
# `independent_belief`; any other value (e.g. the default) skips them.
parser.add_argument(
    "--rm-update-algo",
    default="perfect_rm",
    help="[perfect_rm, perfect_detector, independent_belief]"
)


def _unwrap_minigrid(wrapped):
    """Follow the ``.env`` chain of wrappers down to the backend MiniGridEnv.

    Args:
        wrapped: The (possibly multiply) wrapped environment.

    Returns:
        The first object in the chain that is a ``MiniGridEnv``.

    Raises:
        TypeError: if the chain ends before a ``MiniGridEnv`` is reached.
    """
    current = wrapped
    while not isinstance(current, MiniGridEnv):
        inner = getattr(current, 'env', None)
        if inner is None:
            raise TypeError(
                f'could not find a MiniGridEnv inside {type(wrapped).__name__}'
            )
        current = inner
    return current


if __name__ == '__main__':
    args = parser.parse_args()

    # `env` is the wrapper returned by our `make_env`. The handlers above use
    # its RM-specific API (update_rm_beliefs, current_u_id, belief_u_dist).
    env = make_env(args.env, args.rm_update_algo)
    # `base_env` is the backend MiniGridEnv, used for rendering, the action
    # set and the step counter.
    base_env = _unwrap_minigrid(env)

    window = Window('gym_minigrid - ' + args.env)
    window.reg_key_handler(key_handler)

    reset()

    # Blocking event loop
    window.show(block=True)